[New Feature]: Support SP + DP parallel on Wan training #1223

Open
mahaocong90 wants to merge 4 commits into modelscope:main from
mahaocong90:dev-support-sequence-parallal-for-wan-training

Conversation

@mahaocong90

This PR adds SP + DP support for WAN training through the USP interface.

Description
Added a new configuration parameter --sp_size. USP is enabled when --sp_size > 1. Below is an example of enabling sequence parallelism on one node with 8 GPUs, with SP = 4 and DP = 2.

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path VidData/video \
  --dataset_metadata_path VidData/data/train/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 49 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --sp_size 4 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-TI2V-5B_full" \
  --trainable_models "dit" \
  --extra_inputs "input_image"

Environment version

  • OS: Ubuntu 24.04
  • CUDA driver: 550.163.01 + CUDA 12.9
  • Python: 3.12.3
  • torch: 2.8.0
  • xfuser: 0.4.5
  • transformers: 4.55.2
  • GPU: A800 x 8, one node

Loss Accuracy
I used the Wan2.2 TI2V-5B model for debugging. Functional testing was performed with the Databoost/VidData dataset from Hugging Face, which contains 1,006 MP4 video files and a CSV description file and conforms to the DiffSynth dataset format.
To verify that sequence parallelism (SP) does not affect the loss, I compared the loss between the dp=2 + sp=4 and dp=2 configurations, making sure the number of gradient accumulation steps per weight update was identical. I then fixed the random seeds for sampler, noise, and timestep generation to guarantee identical inputs.

The launch parameters for the dp=2 configuration:

accelerate launch --num_processes 2 examples/wanvideo/model_training/train.py \
  --dataset_base_path VidData/video \
  --dataset_metadata_path VidData/data/train/metadata.csv \
  --height 480 \
  --width 832 \
  --num_frames 49 \
  --dataset_repeat 1 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-TI2V-5B_full" \
  --trainable_models "dit" \
  --extra_inputs "input_image"

Forward process loss comparison (steps 0-100): [figure]
Backward process loss comparison (epoch=2, global steps=1006): [figure]

During debugging of the backward pass, I ran into two issues related to xDiT. The latest release of the base library does not appear to have merged the PRs addressing these fixes, so I manually patched xDiT in my debugging environment to get correct backward behavior.

Issues for xDiT
Issue 1 is related to PR xdit-project/xDiT#598:
"xFuserRingFlashAttnFunc has 17 inputs (including ctx), but it inherits the backward() method from RingFlashAttnFunc which only returns 16 values (3 gradients + 13 Nones)!"

Issue 2: The all_gather interface in xDiT uses torch's all_gather_into_tensor. As of the torch 2.9 release, this interface still does not provide a backward method and cannot support autograd. The corresponding commit in the torch community (pytorch/pytorch#168140) has not yet been merged. Therefore, a simple replacement of all_gather_into_tensor in xDiT with torch.distributed.nn.functional.all_gather is applied here to enable autograd support.

Additionally, I have only tested Wan2.2 TI2V-5B with a small dataset in my development environment. The results appear to match expectations. I will run further tests with other Wan models to confirm that there is no impact on other core functionality.

Thanks

… DP + SP grouping configuration. The computational logic follows: DP = world_size / SP.
@gemini-code-assist (Contributor)

Summary of Changes

Hello @mahaocong90, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed training capabilities for WAN models by introducing support for combined Sequence Parallelism (SP) and Data Parallelism (DP). The changes enable users to specify a sequence parallelism size, which then configures the training pipeline, data loading, and loss calculation to operate efficiently across multiple GPUs, distributing both data batches and sequence elements. This aims to improve training scalability and performance, particularly for large models and long sequences.

Highlights

  • Sequence Parallelism Integration: Introduced the --sp_size argument and integrated Sequence Parallelism (SP) with Data Parallelism (DP) for WAN training via the USP interface, enabling more efficient distributed training.
  • Distributed Data Handling: Modified the FlowMatchSFTLoss function to ensure consistent random noise and timestep generation across sequence parallel ranks by broadcasting these values from rank 0 within the SP group.
  • Flexible Dataloader for Parallelism: Implemented a new build_dataloader function to correctly configure DistributedSampler for combined DP and SP setups, guaranteeing that each SP group receives identical data samples.
  • Pipeline Configuration Updates: Updated the WanVideoPipeline constructor and its from_pretrained method to accept and utilize sp_size, facilitating proper Unified Sequence Parallelism (USP) setup.
  • Enhanced Distributed Environment Setup: Improved the initialize_usp function to dynamically calculate and set data_parallel_degree and sequence_parallel_degree based on the provided sp_size and the total world size (see the sketch after this list).
  • Improved Training Loop Logging: Added detailed timing metrics (forward, backward, optimizer, step) and a more informative progress bar to the training loop in runner.py for better debugging and performance monitoring.
  • xDiT Compatibility Notes: Included comments in relevant files (wan_video.py, xdit_context_parallel.py) to document known xDiT issues regarding xFuserRingFlashAttnFunc's backward method and torch.distributed.all_gather_into_tensor's lack of autograd support, noting that manual patches were used for debugging.
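
As a rough illustration of the grouping logic summarized above (a minimal sketch, not the PR's actual initialize_usp implementation; it assumes sp_size divides the world size evenly and that consecutive ranks form an SP group):

def compute_parallel_degrees(rank: int, world_size: int, sp_size: int):
    # DP = world_size / SP: ranks are split into world_size // sp_size groups,
    # and every group of sp_size consecutive ranks shares the same data samples.
    assert world_size % sp_size == 0, "sp_size must divide world_size"
    dp_size = world_size // sp_size
    dp_rank = rank // sp_size   # which SP group this rank belongs to
    sp_rank = rank % sp_size    # position of this rank inside its SP group
    return dp_size, dp_rank, sp_rank

# Example: 8 GPUs with sp_size=4 -> dp_size=2;
# ranks 0-3 form SP group 0 and ranks 4-7 form SP group 1.
for r in range(8):
    print(r, compute_parallel_degrees(r, 8, 4))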



@gemini-code-assist (Contributor) left a comment

Code Review

This pull request introduces support for Sequence Parallelism (SP) combined with Data Parallelism (DP) for Wan training. The changes primarily involve adding and propagating the sp_size configuration to correctly set up the distributed environment and data loaders. The approach to ensure data consistency across SP ranks by synchronizing random noise and timesteps is sound. However, I've identified a critical issue in the build_dataloader function concerning worker initialization that will cause a runtime error. Additionally, there are several medium-severity issues, including code duplication and leftover debugging print statements, which should be addressed to enhance code quality and maintainability.

Comment on lines 49 to 59

def worker_seed_init(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

dataloader_kwargs = dict(
    dataset=dataset,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=True,
    worker_init_fn=worker_seed_init,

critical

There are a few critical issues with the worker_seed_init function and its usage that will lead to runtime errors and incorrect behavior:

  1. Missing Imports: The random and np (numpy) modules are used but not imported in this file, which will cause a NameError.
  2. Incorrect Signature: The worker_init_fn for a DataLoader receives the worker_id as an argument. The function signature should be def worker_init_fn(worker_id):.
  3. Unused Seed: The seed parameter passed to build_dataloader is not being used to seed the workers. Each worker should be seeded differently based on a base seed and its ID to ensure reproducibility.

Here's a suggested fix that addresses these points. Please also remember to add import random and import numpy as np at the top of the file.

Suggested change

-def worker_seed_init(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
+def worker_init_fn(worker_id):
+    worker_seed = seed + worker_id
+    random.seed(worker_seed)
+    np.random.seed(worker_seed)
+    torch.manual_seed(worker_seed)

 dataloader_kwargs = dict(
     dataset=dataset,
     sampler=sampler,
     num_workers=num_workers,
     pin_memory=True,
-    worker_init_fn=worker_seed_init,
+    worker_init_fn=worker_init_fn,

@mahaocong90 (Author)

It seems unnecessary to specify worker_init_fn during dataloader initialization. It was only used to fix randomness when aligning loss precision, so I’ve removed it for now.

Comment on lines 37 to 39
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))


medium

The variables max_timestep_boundary and min_timestep_boundary are already defined on lines 11-12. This recalculation is redundant and can be removed to improve code clarity and reduce duplication.

Comment on lines 37 to 45

        print(f"accelerator.processid={rank}, accelerator.num_processes={world_size}, "
              f"sp_size={sp_size}, dp_size={dp_size}, dp_rank={dp_rank}")
    else:
        if accelerator is not None:
            dp_size = accelerator.num_processes
            dp_rank = accelerator.process_index
        else:
            raise ValueError(f"Accelerator is None.")
        print(f"dp_size={dp_size}, dp_rank={dp_rank}")
medium

These print statements appear to be for debugging purposes. It's recommended to remove them or replace them with a proper logging framework (e.g., logging.debug(...)) to control verbosity and keep the standard output clean, especially in non-debug runs.

@toyot-li

Issue2: The all_gather interface in xdit utilizes torch’s all_gather_into_tensor interface. As for the torch 2.9 release version, this interface still does not provide a backward method and cannot support automatic autograd. The commit in the torch community (pytorch/pytorch#168140) has not yet been merged. Therefore, a simple replacement of all_gather_into_tensor in xdit with torch.distributed.nn.functional.all_gather is applied here to enable autograd support.

@mahaocong90 Could you please specify how to solve this issue in detail, e.g., which files in xdit are to be modified? Thanks.

@mahaocong90 (Author) commented Feb 12, 2026

@mahaocong90 Could you please specify how to solve this issue in detail, e.g., which files in xdit are to be modified? Thanks.

Yes, my fix is like this (xfuser version 0.4.5), in /usr/local/lib/python3.12/dist-packages/xfuser/core/distributed/group_coordinator.py:

    def all_gather(
        self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        ...
        # Replace all_gather_into_tensor with torch.distributed.nn.functional.all_gather.
        # All-gather.
        -torch.distributed.all_gather_into_tensor(
        -    output_tensor, input_, group=self.device_group
        -)
        +gathered_list = torch.distributed.nn.functional.all_gather(input_, group=self.device_group)
        +output_tensor = torch.cat(gathered_list, dim=0)

This fix works well on my PyTorch 2.8.0 environment.
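
As a quick standalone sanity check (a sketch independent of xfuser, single process with the gloo backend) that gradients really do flow through torch.distributed.nn.functional.all_gather:

import os
import torch
import torch.distributed as dist
from torch.distributed.nn.functional import all_gather

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, requires_grad=True)
gathered = all_gather(x)                 # returns a list of tensors and is autograd-aware
out = torch.cat(gathered, dim=0).sum()
out.backward()                           # no "autograd kernel was not registered" warning
print(x.grad)                            # all ones: the gradient flows back through all_gather

dist.destroy_process_group()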

@toyot-li commented Feb 12, 2026

@mahaocong90 Thanks for your prompt response!

Before applying this fix, there were warnings like the following:

/root/anaconda3/envs/diffsynth/lib/python3.11/site-packages/torch/autograd/graph.py:865: UserWarning: c10d::allgather_base: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:76.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

Following your solution, the above warning disappears during training. It works!

PS: the same xfuser version, with pytorch version 2.10.0+cu128

@mahaocong90 (Author) commented Feb 13, 2026


Additionally, my PR still has some open issues. Testing revealed that errors still occur when using DeepSpeed. This might be because, to ensure that all ranks within an SP group get the same sampler when using USP, I removed the dataloader wrapping applied by Accelerate's prepare method.

def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    sp_size: int = 1,
    args = None,
):
    ...
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    -dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    -model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    +dataloader = build_dataloader(accelerator, dataset, num_workers, sp_size)
    +model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

However, this causes DeepSpeed preparation to raise an error:

-> if deepspeed_plugin.is_auto("train_micro_batch_size_per_gpu"):
(Pdb) c
[rank0]: Traceback (most recent call last):
[rank0]:   File "/pfs/mahaocong/DiffSynth-Studio/examples/wanvideo/model_training/train.py", line 208, in <module>
[rank0]:     launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
[rank0]:   File "/pfs/mahaocong/DiffSynth-Studio/diffsynth/diffusion/runner.py", line 96, in launch_training_task
[rank0]:     model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/accelerator.py", line 1547, in prepare
[rank0]:     result = self._prepare_deepspeed(*args)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/accelerate/accelerator.py", line 2110, in _prepare_deepspeed
[rank0]:     raise ValueError(
[rank0]: ValueError: When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders with `batch_size` attribute returning an integer value or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`.
[rank0]:[W213 09:44:59.058929999 ProcessGroupNCCL.cpp:1505] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

It appears that manual configuration is necessary after Accelerate initialization:

if __name__ == "__main__":
    parser = wan_parser()
    args = parser.parse_args()
    accelerator = accelerate.Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
    )
    +accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1

I am trying to fix this issue.
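
One possible guard for that workaround (a sketch only; it follows the attribute path shown in the traceback and leaves runs without DeepSpeed untouched):

    # Only touch the DeepSpeed config when a DeepSpeed plugin is actually active,
    # and only if the user has not already set the value in the json config.
    ds_plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if ds_plugin is not None and ds_plugin.deepspeed_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
        ds_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1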

@toyot-li

Oh, I noticed the differences you mentioned.

  1. I didn't remove the dataloader from accelerator.prepare. Is it a must to exclude the dataloader from accelerator.prepare?
  2. I use accelerate launch train.py --args for launching, so it does not trigger the DeepSpeed-related error?

@mahaocong90 (Author)


Yes, I think it is necessary to remove Accelerate's wrapping of the dataloader when using USP. After Accelerate wraps the dataloader, the sampler iterates as follows (my accelerate version is 1.10.1):

Case 1: DataLoaderConfiguration with split_batches=False, dispatch_batches=False
In this case, BatchSamplerShard's _iter_with_no_split method distributes batches to the ranks in round-robin order based on idx % self.num_processes == self.process_index (e.g., for a dataset [1, 2, 3, 4, 5, 6] with two ranks: rank0 receives [1, 3, 5], rank1 receives [2, 4, 6]).

    def _iter_with_no_split(self):
        initial_data = []
        batch_to_yield = []
        for idx, batch in enumerate(self.batch_sampler):
            # We gather the initial indices in case we need to circle back at the end.
            if not self.drop_last and idx < self.num_processes:
                initial_data += batch
            # We identify the batch to yield but wait until we are sure every process gets a full batch before actually
            # yielding it.
            if idx % self.num_processes == self.process_index:
                batch_to_yield = batch
            if idx % self.num_processes == self.num_processes - 1 and (
                self.batch_size is None or len(batch) == self.batch_size
            ):
                yield batch_to_yield
                batch_to_yield = []

Case 2: DataLoaderConfiguration with split_batches=False, dispatch_batches=True
Here, rank0 fetches batches from the dataloader and dispatches them to the other ranks. Whether the dispatched data is replicated depends on whether Accelerate has created a model-parallelism group (currently, Accelerate only supports TP parallelism):
If a TP parallelism group is created: rank0 fetches one batch from the dataloader, replicates it, and distributes identical copies to all ranks.
If no TP parallelism group exists: rank0 fetches a sequence of batches from the dataloader and distributes distinct batches to the other ranks.

    def _fetch_batches(self, iterator):
        ...
        try:
            # for TP case avoid using split_batches
            # since it would mean that the dataloader should be spilling out
            # duplicates of batches.
            if self.split_batches:
                # One batch of the main iterator is dispatched and split.
                if self.submesh_tp:
                    logger.warning(
                        "Use of split_batches for TP would need the dataloader to produce duplicate batches,"
                        "otherwise, use dispatch_batches=True instead."
                    )
                self._update_state_dict()
                batch = next(iterator)
            else:
                # num_processes batches of the main iterator are concatenated then dispatched and split.
                # We add the batches one by one so we have the remainder available when drop_last=False.
                batches = []
                if self.submesh_tp:
                    # when tp, extract single batch and then replicate
                    self._update_state_dict()
                    batch = next(iterator)
                    batches = [batch] * self.state.num_processes
                else:
                    for _ in range(self.state.num_processes):
                        self._update_state_dict()
                        batches.append(next(iterator))
                try:
                    batch = concatenate(batches, dim=0)

So it seems that Accelerate's dataloader wrapping is not suitable for SP parallelism.

You can print the prompt associated with the training video to check whether the ranks within the same SP group get the same sample at each step.

def launch_training_task(
    ...
    for data in progress:
        print(f"[train] Rank{accelerator.process_index}, step{train_step}, prompt: {data['prompt']}")

@toyot-li

Thanks for your detailed explanation!

In my case, I use the following code snippets:

    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

    for epoch_id in range(num_epochs):
        sampler.set_epoch(epoch_id)
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                    if dp_rank == 0 and (sp_size <= 1 or sp_rank == 0):
                        with open('loss_dp0.txt', 'a') as f:
                            f.write(str(loss.item()) + '\n')
                    if dp_rank == 1 and (sp_size <= 1 or sp_rank == 0):
                        with open('loss_dp1.txt', 'a') as f:
                            f.write(str(loss.item()) + '\n')
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)

I still have two questions.

  1. I think there should be a sampler.set_epoch() operation.
  2. As above, I write the loss into a local txt file, but I find that the losses differ between the following two launches.
CUDA_VISIBLE_DEVICES=0,1 accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
  --height 240 \
  --width 416 \
  --dataset_repeat 10 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image,end_image" \
  --sp_size 1

and

accelerate launch examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_s2v.csv \
  --height 240 \
  --width 416 \
  --dataset_repeat 10 \
  --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 1 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image,end_image" \
  --sp_size 4

I have already passed a fixed seed to get_pipeline_inputs as follows:

    def get_pipeline_inputs(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {}
        inputs_shared = {
            # Assume you are using this pipeline for inference,
            # please fill in the input parameters.
            "input_video": data["video"],
            "height": data["video"][0].size[1],
            "width": data["video"][0].size[0],
            "num_frames": len(data["video"]),
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "cfg_scale": 1,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "cfg_merge": False,
            "vace_scale": 1,
            "max_timestep_boundary": self.max_timestep_boundary,
            "min_timestep_boundary": self.min_timestep_boundary,
            "seed": 42
        }
        inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

May I know how you align the losses as your figure indicates? Thanks.

@mahaocong90 (Author) commented Feb 14, 2026

May I know how you align the losses as your figure indicates? Thanks.

OK, let me describe my approach.
First, select an alignment target. I use the DP result as the baseline, e.g., dp=1 vs. dp=1 + sp=2. With batch_size=1 and gradient_accumulation_steps=1, the loss computed at the same step should be equal or nearly identical.
Since the training process is divided into forward (fwd) and backward (bwd) phases, I start by comparing the fwd results:

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                #accelerator.backward(loss)
                #optimizer.step()
                #scheduler.step()

For the TI2V model, the fwd computation involves the following steps: get data from the dataset, text embedding, VAE encoder, DiT, then compute the predicted noise and the loss.

Step 1. Ensure that the input is consistent, i.e., within the same SP group, each SP rank gets the same sample at each step.

Step 2. DiffSynth generates random noise and a timestep_id before the DiT forward at each step, which then produce the noisy input and the modulation input. Therefore, it is also necessary to ensure that the noise and timestep are identical on every SP rank. I generate the random values on SP rank0 and broadcast them to the other SP ranks:

        sp_group = get_sp_group()
        if get_sequence_parallel_rank() == 0:
            timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
            timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
        else:
            timestep = torch.zeros(1, dtype=pipe.torch_dtype, device=pipe.device)
        sp_group.broadcast(timestep, src=0)

Alternatively, the same random-seed generator can be created on each SP rank, with the seed kept identical within an SP group at each step, so that the same random inputs are obtained:

        gen = torch.Generator()
        gen.manual_seed(seed + step_id + sp_group_id)
        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,), generator=gen)

Step 3. Next, make sure that any other randomness in the computation is fixed. For instance, if your algorithm adds random frames to the input or injects other random noise, those random seeds also need to be fixed.
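
For example, a generic helper for pinning the remaining global sources of randomness while debugging (a sketch, not part of this PR):

import random
import numpy as np
import torch

def fix_all_seeds(seed: int = 42):
    # Pin every global RNG so that repeated runs produce identical tensors.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for determinism in cuDNN kernels while debugging.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False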

Step 4. Once the inputs are fully aligned, test whether the fwd loss values are equal. If they are not completely consistent, you have to print intermediate tensors and compare them step by step. This is tedious; you can reduce the number of model layers to simplify the problem (e.g., change the 24 DiT blocks of Wan2.2 TI2V-5B to 1).
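
One simple way to do that comparison is to dump intermediate tensors from both runs and diff them offline (a sketch; the tags, file names, and dump points are placeholders):

import torch

def dump_tensor(run_name: str, tag: str, tensor: torch.Tensor):
    # Call this at the same point in both the DP-only run and the DP+SP run.
    torch.save(tensor.detach().float().cpu(), f"{run_name}_{tag}.pt")

def compare_dump(tag: str, run_a: str, run_b: str, atol: float = 1e-5) -> bool:
    a = torch.load(f"{run_a}_{tag}.pt")
    b = torch.load(f"{run_b}_{tag}.pt")
    if a.shape != b.shape:
        # Inside the SP region activations are sharded along the sequence
        # dimension, so gather them (or compare against the matching slice) first.
        print(f"{tag}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        return False
    max_diff = (a - b).abs().max().item()
    print(f"{tag}: max abs diff = {max_diff:.3e}")
    return torch.allclose(a, b, atol=atol)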

Step 5. Once the fwd loss is aligned, add the bwd computation. torch's autograd automatically handles gradient calculation and the update process. The problem I encountered was caused by an operator that did not provide a bwd implementation, which broke the computational graph. Then check whether any parameter gradient is all zeros, NaN, or inf:

        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grad = param.grad.detach()
                if torch.allclose(grad, torch.zeros_like(grad), atol=1e-8):
                    print(f"[grad check] {name}: gradient is all zeros")
                if torch.isnan(grad).any() or torch.isinf(grad).any():
                    print(f"[grad check] {name}: gradient contains NaN/inf")

Then locate where the anomaly first appears. If there are outliers, register a hook on the suspect tensor (tensor.register_hook(grad_hook) is called when bwd computes its gradient). Since the order of gradient computation is the reverse of fwd, if 0/NaN/inf first appears at a certain layer, the operator there may not provide a bwd implementation or may compute an incorrect gradient.

    norm2_x = self.norm2(x_after_cross_attn)
    if norm2_x.requires_grad:
        norm2_x.retain_grad()

        def grad_hook(grad):
            # Inspect the incoming gradient here.
            print(f"norm2 grad: max={grad.abs().max().item()}, has_nan={torch.isnan(grad).any().item()}")

        norm2_x.register_hook(grad_hook)

That is my workflow; you are welcome to discuss and add more.

… and DeepSpeed zero2, where train_micro_batch_size_per_gpu must be specified in the json.